import numpy as np
import qrcode
import json
from multiprocessing import Pool, Manager, freeze_support
import pandas as pd
import cv2

from glob import glob

from tqdm import tqdm


def generate_qrcode(data, image_name):
    """Render ``data`` as a QR code and save it to ``image_name``.

    The image is built by ``generate_qrcode_x`` so the files written here for
    the training set use exactly the same QR settings as the images built in
    memory at inference time; keeping one copy of the settings prevents the
    two from drifting apart.

    Args:
        data: Text to encode.
        image_name: Output image path (callers in this file pass ``.png``
            paths).
    """
    img = generate_qrcode_x(data)
    img.save(image_name)

def generate_qrcode_x(data):
    """Build a QR-code image for ``data`` and return it without saving it.

    Settings: version 1 as the starting size (with ``fit=True`` the library
    may pick a larger version if the data needs it), low error correction,
    8-pixel modules, a 2-module border, green modules on a white background.

    Args:
        data: Text to encode.

    Returns:
        The image object produced by ``QRCode.make_image``.
    """
    builder = qrcode.QRCode(
        version=1,
        error_correction=qrcode.constants.ERROR_CORRECT_L,
        box_size=8,
        border=2,
    )
    builder.add_data(data)
    builder.make(fit=True)
    return builder.make_image(fill_color="green", back_color="white")
def gen_image(j, one, total_list, out_dir="/home/aistudio/text_image"):
    """Write one QR image per proper prefix of ``one`` and record its label.

    For each position ``i`` in ``one[:-1]``, the running prefix (the first
    ``i + 1`` elements of ``one`` concatenated) is saved as
    ``{out_dir}/{j}_{i}.png``, and ``[one[i + 1], path]`` is appended to
    ``total_list``. Each sample therefore pairs a prefix image with the element
    that follows it.

    Args:
        j: Index of the text; used as the file-name prefix.
        one: Text to split into prefixes.
        total_list: List-like collector (a ``Manager().list()`` proxy when run
            from ``gen_text_to_image``) that receives ``[label, path]`` rows.
        out_dir: Directory for the images; created if it does not exist.
    """
    import os

    # A missing directory would make every save fail inside pool workers.
    os.makedirs(out_dir, exist_ok=True)
    prefix = ""
    # zip(one, one[1:]) pairs each element of one[:-1] with its successor.
    for i, (ch, label) in enumerate(zip(one, one[1:])):
        print(j, i)  # progress output
        prefix += ch
        image_path = "{}/{}_{}.png".format(out_dir, j, i)
        generate_qrcode(prefix, image_path)
        total_list.append([label, image_path])


def gen_text_to_image(processes=7):
    """Build the QR-image dataset from ``唐诗.json`` and pickle its index.

    Reads the JSON file, takes field ``[4]`` of every record with newlines
    removed, keeps only texts of exactly 32 characters, and renders every
    prefix of each text through ``gen_image`` in a process pool. The collected
    ``[label, image_path]`` rows are written to ``image_data_set.pkl``.

    Args:
        processes: Number of worker processes.

    Raises:
        Any exception raised by a worker is re-raised here rather than being
        silently dropped, so a partial dataset is never pickled.
    """
    with open('唐诗.json', 'r', encoding='utf-8') as f:
        dataset = json.load(f)

    texts = [record[4].replace('\n', '') for record in dataset]
    # Only exactly-32-character texts are kept.
    texts = [t for t in texts if len(t) == 32]

    with Manager() as manager:
        total_list = manager.list()
        with Pool(processes=processes) as pool:
            pending = [
                pool.apply_async(gen_image, args=(j, one, total_list))
                for j, one in enumerate(texts)
            ]
            pool.close()
            pool.join()
            # apply_async only surfaces a worker's exception through .get().
            for result in pending:
                result.get()
        # Copy out of the proxy before the manager process shuts down.
        rows = list(total_list)
    pd.to_pickle(rows, "image_data_set.pkl")


import paddle


class VlmBlock(paddle.nn.Layer):
    """Three parallel 3x3 convolutions, merged by concat + sum, then ReLU.

    The output has ``2 * output_dim`` channels. The first two convolutions
    (``output_dim`` channels each) are concatenated along axis 1 and added to
    the third (``2 * output_dim`` channels), then ReLU is applied. With
    ``padding=1`` the spatial size is unchanged; when ``down_flag`` is set, a
    ``MaxPool2D(2, 2)`` follows.

    NOTE: state-dict keys are derived from the sub-layer attribute names below;
    renaming them would break loading of saved weights via ``load_dict``.
    """

    def __init__(self, input_dim, output_dim, down_flag):
        super().__init__()
        self.down_flag = down_flag
        self.one_layer = paddle.nn.Conv2D(
            input_dim, output_dim, 3, padding=1, bias_attr=False)
        self.two_layer = paddle.nn.Conv2D(
            input_dim, output_dim, 3, padding=1, bias_attr=False)
        self.three_layer = paddle.nn.Conv2D(
            input_dim, 2 * output_dim, 3, padding=1, bias_attr=False)
        if down_flag:
            self.down_layer = paddle.nn.MaxPool2D(2, 2)
        self.relu = paddle.nn.ReLU()

    def forward(self, x):
        """Apply the block; ``x`` is assumed NCHW with ``input_dim`` channels."""
        paired = paddle.concat([self.one_layer(x), self.two_layer(x)], axis=1)
        merged = self.relu(paired + self.three_layer(x))
        return self.down_layer(merged) if self.down_flag else merged


class VLM(paddle.nn.Layer):
    """Six stacked ``VlmBlock``s followed by two parallel linear heads.

    Channels run 1 -> 6 -> 12 -> 24 -> 24 -> 24 -> 24 (each block outputs
    ``2 * output_dim``), and four blocks halve the spatial size. A 128x128
    single-channel input (as built in the main block) therefore reaches the
    heads as 24 * 8 * 8 = 1536 features; other input sizes would not match
    the hard-coded 1536.

    ``forward`` returns two tensors of shape ``[batch, class_num]``.
    """

    def __init__(self, class_num):
        super().__init__()
        # Attribute names form the state-dict keys; keep them stable.
        self.one_layer = VlmBlock(1, 3, True)
        self.two_layer = VlmBlock(6, 6, False)
        self.three_layer = VlmBlock(12, 12, True)
        self.four_layer = VlmBlock(24, 12, False)
        self.five_layer = VlmBlock(24, 12, True)
        self.six_layer = VlmBlock(24, 12, True)

        self.out_layer = paddle.nn.Linear(1536, class_num)
        self.out_layer1 = paddle.nn.Linear(1536, class_num)

    def forward(self, x):
        stages = (
            self.one_layer,
            self.two_layer,
            self.three_layer,
            self.four_layer,
            self.five_layer,
            self.six_layer,
        )
        for stage in stages:
            x = stage(x)
        flat = x.reshape([x.shape[0], -1])
        return self.out_layer(flat), self.out_layer1(flat)


class VlmLoss(paddle.nn.Layer):
    """Cross-entropy on ``x`` after scaling it by ``1 - softmax(xx)``.

    ``x`` is multiplied element-wise by ``1 - softmax(xx, axis=-1)`` and the
    result is passed to ``cross_entropy`` with targets ``y``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, y, xx):
        functional = paddle.nn.functional
        scale = 1 - functional.softmax(xx, axis=-1)
        return functional.cross_entropy(x * scale, y)


def reshape_data(pattern="E:/text_image/*", size=(128, 128)):
    """Overwrite every image matching ``pattern`` with a grayscale resized copy.

    Each file is read with ``cv2.imread(path, 0)`` (grayscale), resized to
    ``size``, and written back to the same path. Files OpenCV cannot decode
    (``imread`` returns ``None``, e.g. non-images matched by ``*``) are
    skipped rather than aborting the run midway.

    Args:
        pattern: ``glob`` pattern selecting the images.
        size: Target ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        List of paths skipped because they could not be read.
    """
    skipped = []
    for path in tqdm(glob(pattern)):
        image = cv2.imread(path, 0)
        if image is None:
            skipped.append(path)
            continue
        cv2.imwrite(path, cv2.resize(image, size))
    if skipped:
        print("skipped {} unreadable file(s)".format(len(skipped)))
    return skipped


if __name__ == '__main__':
    freeze_support()
    # Dataset preparation steps; uncomment to run:
    # gen_text_to_image()
    # Resize the generated images:
    # reshape_data()

    records = pd.read_pickle("/home/aistudio/data/data263456/image_data_set.pkl")
    # Label vocabulary: sorted set of all labels. Presumably this is the same
    # label -> index mapping used in training (training code not shown).
    voc = sorted({record[0] for record in records})

    vlm = VLM(class_num=len(voc))
    vlm.load_dict(paddle.load("/home/aistudio/vlm.pdparams"))
    vlm.eval()

    word = "我"
    # Inference only: no_grad stops paddle from building an autograd graph
    # on every generation step.
    with paddle.no_grad():
        for _ in range(31):
            rgb = np.array(generate_qrcode_x(word))
            bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
            # NOTE(review): reshape_data converts to grayscale before resizing;
            # here the order is reversed. Confirm this matches the training
            # preprocessing.
            gray = cv2.cvtColor(cv2.resize(bgr, (128, 128)), cv2.COLOR_BGR2GRAY)
            batch = paddle.to_tensor(gray / 255).astype('float32').reshape([1, 1, 128, 128])
            logits, gate = vlm(batch)
            scores = logits * (1 - paddle.nn.functional.softmax(gate, axis=-1))
            # argmax yields a one-element tensor; convert to a plain int index.
            word += voc[int(paddle.argmax(scores, -1).item())]
            print(word)


